from Model.InferenceModule.inference_module import InferenceModule
from tianshou.data import Batch
import numpy as np
from Model.InferenceModule.module_utils import trace_log_probs

class PassiveModule(InferenceModule):
    """Inference module that evaluates the passive model on a batch.

    Depending on ``args.inter.use_active_as_passive`` this wraps either the
    supplied ``passive_model`` or reuses ``active_model``; in the latter case
    ``__call__`` queries the model with an identity interaction mask.
    """

    def __init__(self, args, extractor, forward_dist, passive_model, active_model):
        """Select the underlying model and initialize the optimizer.

        Args:
            args: configuration namespace; ``args.inter`` is kept as
                ``self.mp`` and its ``use_active_as_passive`` flag is read here.
            extractor: forwarded to ``InferenceModule.__init__``.
            forward_dist: stored as ``self.forward_dist`` (presumably used by
                base-class helpers such as ``_target_dists`` -- confirm).
            passive_model: model used when ``use_active_as_passive`` is false.
            active_model: model used when ``use_active_as_passive`` is true.
        """
        super().__init__(args, extractor)
        self.mp = args.inter
        self.name = "all"
        self.use_active_as_passive = self.mp.use_active_as_passive
        self.model = active_model if self.use_active_as_passive else passive_model
        self.forward_dist = forward_dist
        # Must run after self.model is set so the optimizer sees the chosen model.
        self.init_optimizer(args)

    def __call__(self, batch, valid, extractor, normalizer, additional=None, grad_settings=None, log_batch=None, keep_all=False):
        """Run the selected model on ``batch`` and package its outputs.

        Args:
            batch: tianshou ``Batch`` providing ``target`` and ``obs`` (and any
                keys named in ``log_batch``).
            valid: validity array, filtered with the same omit flags as
                ``batch`` and passed to the model.
            extractor: provides ``num_objects`` for the identity mask and for
                ``trace_log_probs``.
            normalizer: not used by this method.
            additional: names of extra model outputs to request (passed as
                ``ret_settings``); the i-th name is stored on the result with
                the i-th entry of the model info. Defaults to no extras.
            grad_settings: forwarded to the model. Defaults to ``[]``.
            log_batch: keys of the omit-filtered ``batch`` to copy onto the
                result. Defaults to none.
            keep_all: forwarded to ``self.get_omit``.

        Returns:
            Batch with ``target``, ``params``, ``mask``, ``log_probs``,
            ``input``, ``embed``, ``omit_flags``, ``passive_input``,
            ``trace_log_probs`` plus one field per name in ``additional`` and
            ``log_batch``.
        """
        # Build fresh lists per call: shared mutable defaults would leak state
        # across calls if the model (or a caller) mutated them.
        additional = [] if additional is None else additional
        grad_settings = [] if grad_settings is None else grad_settings
        log_batch = [] if log_batch is None else log_batch

        omit_flags = self.get_omit(batch, keep_all=keep_all, keep_invalid=True)
        key_state = batch.target[omit_flags]
        query_state = batch.obs[omit_flags]
        valid = valid[omit_flags]

        # When the active model stands in for the passive one, an identity
        # interaction mask is supplied (presumably so each object only uses
        # its own state -- confirm against the model's ``m`` semantics).
        input_mask = np.eye(extractor.num_objects) if self.use_active_as_passive else None
        params, mask, info = self.model(
            np.concatenate([key_state, query_state], axis=-1),
            m=input_mask,
            valid=valid,
            dist_settings=['flat'],
            ret_settings=additional,
            grad_settings=grad_settings,
        )
        passive_input, _keys, _queries, info1, info2 = info
        # Pair the two info sequences element-wise.
        info = list(zip(info1, info2))

        target, log_probs = self._target_dists(batch, params, omit_flags=omit_flags)
        # NOTE(review): ``input``/``embed`` read info[0] and info[2], which
        # requires at least three paired entries -- confirm the model always
        # provides them regardless of ``additional``.
        result = Batch(target=target, params=params, mask=mask, log_probs=log_probs, input=info[0], embed=info[2], omit_flags=omit_flags, passive_input=passive_input)
        result.trace_log_probs = trace_log_probs(extractor.num_objects, result.log_probs, batch)
        for i, aname in enumerate(additional):
            result[aname] = info[i]
        if log_batch:
            # Index the batch once instead of once per logged key.
            omitted_batch = batch[omit_flags]
            for k in log_batch:
                result[k] = omitted_batch[k]
        return result